package com.luis.toolsuite.kmeans;

import com.luis.toolsuite.kmeans.model.Cluster;
import com.luis.toolsuite.util.KMeansUtil;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.stream.Collectors;

/**
 * K-means clustering over a list of points represented as {@code Double[]}.
 *
 * <p>Typical use: construct, call {@link #train()}, then read {@link #getClusterList()},
 * {@link #getCenterList()} and {@link #getTotalSSE()}. Instances hold mutable state and use no
 * synchronization.
 */
public class Kmeans {

    /** Convergence method 1: stop once the centroids no longer change between iterations. */
    private static final int EVAL_BY_CENTER = 1;
    /** Convergence method 2: stop once the total within-cluster SSE no longer decreases. */
    private static final int EVAL_BY_SSE = 2;

    private final int k; // number of clusters
    private final int evalMethod; // convergence method: 1 = centroids, 2 = within-cluster SSE
    private final List<Double[]> dataset; // private, shuffled copy of the input points

    private final List<Double[]> bakCenterList; // centroids of the previous iteration
    private final List<Double[]> centerList; // current centroids
    private List<Cluster> clusterList; // current clusters
    private BigDecimal totalSSE; // total within-cluster sum of squared errors
    private BigDecimal bakTotalSSE; // total SSE of the previous iteration; null until known
    private int currentIteration;

    /**
     * Step 0: fixes the number of clusters and prepares the data set.
     *
     * <p>The points are copied into a private list, which is shuffled with
     * {@code KMeansUtil.SEED}. The caller's list is therefore never modified (it may even be
     * unmodifiable), and the initial centroids are reproducible.
     *
     * @param k          number of clusters; must be between 1 and {@code ds.size()}
     * @param evalMethod convergence method: 1 = centroids stop changing, 2 = total SSE stops
     *                   decreasing; any other value relies only on the iteration cap
     * @param ds         the points to cluster
     * @throws NullPointerException     if {@code ds} is null
     * @throws IllegalArgumentException if {@code k} is outside {@code [1, ds.size()]}
     */
    public Kmeans(int k, int evalMethod, List<Double[]> ds) {
        Objects.requireNonNull(ds, "ds");
        if (k < 1 || k > ds.size()) {
            throw new IllegalArgumentException(
                    "k must be between 1 and the data set size (" + ds.size() + "), got " + k);
        }
        this.k = k;
        this.evalMethod = evalMethod;
        // Shuffle a copy so the caller's list is left untouched; same seed => same permutation.
        this.dataset = new ArrayList<>(ds);
        Collections.shuffle(this.dataset, new Random(KMeansUtil.SEED));
        this.bakCenterList = new ArrayList<>(k);
        this.centerList = new ArrayList<>(k);
        this.clusterList = new ArrayList<>(k);
        this.totalSSE = BigDecimal.ZERO;
        this.bakTotalSSE = null;
        this.currentIteration = 0;
    }

    /**
     * Runs k-means until {@link #evaluate()} reports convergence.
     *
     * <p>Iteration state is reset first, so calling this again re-runs from the initial centroids
     * instead of continuing with a stale iteration counter and stale previous centroids.
     */
    public void train() {
        currentIteration = 0;
        bakCenterList.clear();
        bakTotalSSE = null;
        getInitialCenter().calulateCluster();
        while (!evaluate()) {
            currentIteration++;
            recalculateCenter().calulateCluster();
        }
    }

    /**
     * Step 1: uses the first {@code k} points of the shuffled data set as initial centroids.
     * Because the data set was shuffled in the constructor, this is a random sample.
     *
     * @return this instance, for chaining
     */
    public Kmeans getInitialCenter() {
        centerList.clear();
        centerList.addAll(dataset.subList(0, k));
        return this;
    }

    /**
     * Step 2: builds one cluster per current centroid and assigns every point to its nearest
     * centroid. It then computes each cluster's SSE and stores their sum in {@code totalSSE}.
     * (The misspelled name is kept for compatibility with existing callers.)
     *
     * @return this instance, for chaining
     */
    public Kmeans calulateCluster() {
        clusterList.clear();
        clusterList.addAll(centerList.stream().map(Cluster::new).collect(Collectors.toList()));
        dataset.forEach(this::matchCluster);
        clusterList.parallelStream().forEach(Cluster::calculateSSE);
        // Sum the per-cluster SSE values.
        this.totalSSE = clusterList.stream()
                .map(Cluster::getSse)
                .reduce(BigDecimal.ZERO, BigDecimal::add);
        return this;
    }

    /**
     * Adds {@code point} to the cluster whose centroid is closest to it. On ties, the earliest
     * cluster in {@code clusterList} wins, because the comparison is strict.
     */
    private void matchCluster(Double[] point) {
        BigDecimal distance = null;
        Cluster target = null;
        for (Cluster cluster : clusterList) {
            Double[] center = cluster.getCenter();
            BigDecimal thisDistance = KMeansUtil.calculateDistance(point, center);
            if (distance == null || thisDistance.compareTo(distance) < 0) {
                distance = thisDistance;
                target = cluster;
            }
        }
        if (target != null) {
            target.addPoint(point);
        }
    }

    /**
     * Step 3: moves each centroid to the mean of its cluster's points.
     *
     * <p>The previous centroids and total SSE are saved first, for {@link #evaluate()}. A cluster
     * that received no points keeps its current centroid, because the mean of an empty set is
     * undefined. This happens, for example, when two initial centroids are duplicates: ties go to
     * the first one.
     *
     * @return this instance, for chaining
     */
    public Kmeans recalculateCenter() {
        bakCenterList.clear();
        bakCenterList.addAll(centerList);
        bakTotalSSE = totalSSE;
        centerList.clear();
        centerList.addAll(clusterList.parallelStream()
                .map(cluster -> cluster.getPointList().isEmpty()
                        ? cluster.getCenter()
                        : KMeansUtil.calculateMeanCenter(cluster.getPointList()))
                .collect(Collectors.toList()));
        return this;
    }

    /**
     * Step 4: checks the stop conditions.
     * <ul>
     *   <li>Iteration cap: {@code currentIteration >= KMeansUtil.ITERATION_NUM}, always checked
     *       first as the fallback;</li>
     *   <li>method 1: the centroids equal those of the previous iteration;</li>
     *   <li>method 2: the total SSE did not decrease compared with the previous iteration.</li>
     * </ul>
     * Before any previous iteration exists, methods 1 and 2 report "not converged".
     *
     * @return {@code true} if training should stop
     */
    public boolean evaluate() {
        if (currentIteration >= KMeansUtil.ITERATION_NUM) {
            return true;
        }
        switch (evalMethod) {
            case EVAL_BY_CENTER:
                // No previous centroids yet (first check): nothing to compare against.
                return !bakCenterList.isEmpty()
                        && KMeansUtil.centerEqual(bakCenterList, centerList);
            case EVAL_BY_SSE:
                // Assumes Cluster.calculateSSE measures squared distances to the cluster's
                // centroid. Then mean-update + nearest-assignment cannot raise the SSE in exact
                // arithmetic, so "no decrease" means converged. A relative tolerance could
                // replace this exact check if faster stopping is wanted.
                return bakTotalSSE != null && totalSSE.compareTo(bakTotalSSE) >= 0;
            default:
                // Unknown method: rely on the iteration cap only.
                return false;
        }
    }

    /** @return the number of clusters */
    public int getK() {
        return k;
    }

    /** @return the live list of clusters from the most recent assignment step */
    public List<Cluster> getClusterList() {
        return clusterList;
    }

    /** Replaces the cluster list; it must be mutable, since later steps clear and refill it. */
    public void setClusterList(List<Cluster> clusterList) {
        this.clusterList = clusterList;
    }

    /** @return the total within-cluster SSE from the most recent assignment step */
    public BigDecimal getTotalSSE() {
        return totalSSE;
    }

    /** @return the live list of current centroids */
    public List<Double[]> getCenterList() {
        return centerList;
    }

    /** @return the number of centroid updates performed by the last {@link #train()} run */
    public int getCurrentIteration() {
        return currentIteration;
    }
}
